%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Code for Example 3.2
% ETKF with Belanger's method, applied to a 3-dim SDE
% Created by John Harlim 
% Last edited: March 16, 2018  
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

clear all
tic
% NOTE(review): triad.mat presumably provides the true trajectory x, the
% noise amplitude sigma, the time step DT and the true parameter values
% omega, beta, gamma, a used later in this script -- confirm against the file.
load triad

% ---- dimensions and run length ----
n2 = 3;              % number of state variables
n  = n2 + 4;         % augmented state: state variables + 4 deterministic parameters
m  = 3;              % number of observed variables
q  = 2;              % dimension of the system-noise forcing
TCYC = 10000;        % number of assimilation cycles
L  = 4;              % maximum lag used by the secondary filter
En = 10;             % ensemble size

% ---- observation model: the first m components are observed directly ----
H = [eye(m), zeros(m,n-m)];
R = 0.1*var(x(1,:));          % observation-noise variance

% ---- true stochastic parameters: [sigma(1,1)^2*DT, R] ----
alpha = [sigma(1,1)^2*DT, R];
p = numel(alpha);             % number of stochastic parameters

% Generate noisy observations of the first m state components for
% cycles 1..TCYC+1
y = x(1:m,1:TCYC+1) + sqrt(R)*randn(m,TCYC+1);

xa = zeros(n,En,TCYC);      % posterior (analysis) ensemble, one slice per cycle
xb = zeros(n,En,TCYC+1);    % prior (forecast) ensemble, one slice per cycle + initial
% stochastic-parameter estimates; column k+1 is written at the end of
% cycle k, so TCYC+1 columns are needed
ahat = zeros(p,TCYC+1);

% specify initial conditions
xbbar = zeros(n,1);
xbbar(1:3,1) = 1;   % state variables (u,v,w)
xbbar(4,1) = 1;     % omega
xbbar(5,1) = 2;     % beta
xbbar(6,1) = 1;     % gamma
xbbar(7,1) = 1;     % a
ahat(:,1) = ones(p,1);   % initial guess for every stochastic parameter

% Initial ensemble perturbations: draw En Gaussian samples, remove their
% sample mean, then rescale by C^(-1/2), where C is their sample
% covariance (via its eigendecomposition). The result has zero sample mean
% and exactly identity sample covariance. The centred samples have rank
% at most En-1, so C is only invertible when En-1 >= n.
assert(En-1 >= n, 'Ensemble size En must satisfy En-1 >= n (n = %d).', n);
RM = randn(n,En);
RM2 = RM - mean(RM,2)*ones(1,En);
[EF,ES] = eig(RM2*RM2'/(En-1));
RM3 = EF*diag(1./sqrt(diag(ES)))*EF'*RM2;
xb(:,:,1) = xbbar(:,1)*ones(1,En) + RM3;
X = xb(:,:,1) - xbbar(:,1)*ones(1,En);   % prior perturbation matrix


% System-noise input matrix: the q noise components force state rows 2 and 3.
Gamma = zeros(n,q);
Gamma(2:3,1:2) = eye(2);

% Ansatz: the noise covariances are linear in the stochastic parameters,
%   Q = sum_j ahat(j)*Qi(:,:,j),   R = sum_j ahat(j)*Ri(:,:,j).
% Parameter 1 scales the system noise, parameter 2 the observation noise.
Qi = zeros(q,q,p);
Qi(:,:,1) = eye(q);
Ri = zeros(m,m,p);
Ri(:,:,2) = eye(m);

d = zeros(m,L+1);        % innovation history; newest innovation in column L+1
K = zeros(n,m);          % propagated gain (A*Kc), updated every cycle

% ---- state of the secondary (parameter) filter ----
phi  = zeros(n,n);           % closed-loop operator A - K*H
S    = zeros(n,n,p);         % per-parameter covariance recursion
M    = zeros(n,m,p,L+1);     % per-parameter lagged terms; F = H*M for lags >= 1
Mold = M;                    % M from the previous cycle
F    = zeros(m,m,p,L+1);     % secondary observation operators, per parameter and lag
Evv  = zeros(m,m,L+1);       % modelled innovation covariances, per lag
Th   = eye(p);               % covariance of the parameter estimate

% work arrays for the vectorized secondary-filter update
ssigma  = zeros(m^2,1);
Fscript = zeros(m^2,p);
W       = zeros(m^2,m^2);


%% index tables used to assemble the fourth-moment weight matrix W
% For a linear (column-major) index t into an m x m matrix, ka(t) is its
% row and el(t) its column.
lin = 1:m*m;
ka = mod(lin-1,m) + 1;
el = ceil(lin/m);

% in: linear indices of the lower triangle (diagonal included), taken
% column by column -- the unique entries of a symmetric m x m matrix
in = find(tril(true(m)))';
Lin = length(in);
Hs0 = zeros(Lin,p);
Kgs0 = zeros(p,Lin);
Hs = zeros(m^2,p);
Kgs = zeros(p,m^2);

% Lin x Lin tables for the lag-0 weights:
% In1(a,b) = (row_a,row_b), In2 = (col_a,col_b),
% In3 = (row_a,col_b),      In4 = (col_a,row_b), as linear indices
kin = ka(in);
ein = el(in);
[gi, gj] = ndgrid(1:Lin);
In1 = sub2ind([m m], kin(gi), kin(gj));
In2 = sub2ind([m m], ein(gi), ein(gj));
In3 = sub2ind([m m], kin(gi), ein(gj));
In4 = sub2ind([m m], ein(gi), kin(gj));

% m^2 x m^2 tables for the lagged weights over all m^2 entries
[gi, gj] = ndgrid(1:m^2);
In5 = sub2ind([m m], ka(gi), ka(gj));
In6 = sub2ind([m m], el(gi), el(gj));

%% begin filtering
% Each cycle k:
%   1. primary filter: ETKF analysis of y(:,k), with the observation-noise
%      covariance built from the current estimates ahat(:,k);
%   2. deterministic forecast of the analysis ensemble with triadparam;
%   3. secondary filter (Belanger's method): update the stochastic-parameter
%      estimates from products of innovations at lags 0..LT-1, giving ahat(:,k+1).
for k=1:TCYC
    % ---------- primary filter with ETKF ----------
    % shift the innovation history; column L+1 holds the newest innovation
    d(:,1:L) = d(:,2:L+1);
    d(:,L+1) = y(:,k) - H*xbbar;
    Y = H*X;

    % observation-noise covariance R = sum_j ahat(j,k)*Ri(:,:,j)
    % NOTE(review): this is singular if ahat(2,k) was clamped to 0 by the
    % max(...,0) at the end of the previous cycle -- confirm that cannot occur.
    Rtemp = 0;
    for j=1:p
        Rtemp = Rtemp + Ri(:,:,j)*ahat(j,k);
    end

    % ensemble-space update: JJ = (En-1)I + Y'R^{-1}Y,
    % gain Kc = X JJ^{-1} Y' R^{-1}, transform TT = sqrt(En-1) JJ^{-1/2}
    JJ = (En-1)*eye(En) + Y'*(Rtemp\Y);
    [U,SS] = svd(JJ);
    Kc = X*U*pinv(SS)*U'*Y'/Rtemp;
    TT = sqrt(En-1)*U*diag(1./sqrt(diag(SS)))*U';
    Xplus = X*TT;
    xabar = xbbar + Kc*d(:,L+1);
    xa(:,:,k) = xabar*ones(1,En)+Xplus;

    % deterministic ensemble forecast: xa + DT*triadparam(xa)
    % (a forward-Euler step if triadparam returns the tendency)
    xb(:,:,k+1) = xa(:,:,k) + DT*triadparam(xa(:,:,k));

    % ensemble approximation of the linear map A that takes analysis
    % perturbations to forecast perturbations
    xbbar = mean(xb(:,:,k+1),2);
    X = xb(:,:,k+1) - xbbar*ones(1,En);
    A = X*pinv(Xplus);

    % system-noise covariance Q = sum_j ahat(j,k)*Qi(:,:,j)
    Qtemp = 0;
    for j=1:p
        Qtemp = Qtemp + Qi(:,:,j)*ahat(j,k);
    end
    Pb = X*X'/(En-1) + Gamma*Qtemp*Gamma';

    % replace the forecast perturbations by a fresh ensemble with zero
    % sample mean and sample covariance exactly Pb
    RM = randn(n,En);
    RM2 = RM - mean(RM,2)*ones(1,En);
    [EF,ES] = eig(RM2*RM2'/(En-1));
    X = sqrtm(Pb)*EF*diag(1./sqrt(diag(ES)))*EF'*RM2;

    % ---------- secondary filter ----------
    % propagated gain K = A*Kc and closed-loop operator phi = A - K*H;
    % the previous values are kept for the lagged terms
    Kold = K;
    K = A*Kc;
    phiold = phi;
    phi = A - K*H;

    % secondary observation operators: F(:,:,i,l) is the coefficient of
    % parameter i in the (linear-in-ahat) model of the lag-(l-1) innovation
    % covariance
    for i = 1:p
        Mold(:,:,i,:) = M(:,:,i,:);
        M(:,:,i,1) =  S(:,:,i)*H';
        F(:,:,i,1) = H*M(:,:,i,1) + Ri(:,:,i);
        S(:,:,i) =  phi*S(:,:,i)*phi'+Gamma*Qi(:,:,i)*Gamma'+K*Ri(:,:,i)*K';

        % lagged operators need the previous cycle's phi, K and M
        if (k>1)
            M(:,:,i,2) = phiold*Mold(:,:,i,1)-Kold*Ri(:,:,i);
            if (L>1)
                for l = 3:L+1
                    M(:,:,i,l) = phiold*Mold(:,:,i,l-1);
                end
            end
            for l = 2:L+1
                F(:,:,i,l) = H*M(:,:,i,l);
            end
        end
    end

    % modelled covariance E[v(k)v(k)'] of the newest innovation, kept for
    % the last L+1 cycles; used to build the fourth-order weights W
    Evv(:,:,1:L) = Evv(:,:,2:L+1);
    Evv(:,:,L+1) = ahat(1,k)*F(:,:,1,1);
    for i=2:p
        Evv(:,:,L+1) = Evv(:,:,L+1) + ahat(i,k)*F(:,:,i,1);
    end

    % number of lags used: lag 0 only during the first 2L cycles, while
    % the innovation/operator histories are still filling; lags 0..L after
    if (k<2*L)
        LT = 1;
    else
        LT = L+1;
    end

    % sequential Kalman updates of the parameter estimate, one per lag
    meana = zeros(p,LT+1);
    Theta = zeros(p,p,LT+1);
    meana(:,1) = ahat(:,k);
    Theta(:,:,1) = Th;

    for l = 1:LT
        if (l==1)
            % lag 0: d_k d_k' is symmetric, so only its Lin lower-triangular
            % entries (linear indices 'in') are used
            EVV = Evv(:,:,L+1);
            W = real(EVV(In1).*EVV(In2) +EVV(In3).*EVV(In4));

            temp = reshape(d(:,L+1)*d(:,L+1-l+1)',m*m,1);
            ssigma(1:Lin,1) = temp(in);

            for i=1:p
                temp = reshape(F(:,:,i,l),m*m,1);
                Fscript(1:Lin,i) = temp(in);
            end

            Hs0 = Fscript(1:Lin,:);
            Kgs0 = Theta(:,:,l)*Hs0'/(Hs0*Theta(:,:,l)*Hs0'+W(1:Lin,1:Lin));
            Theta(:,:,l+1) = (eye(p)-Kgs0*Hs0)*Theta(:,:,l);
            meana(:,l+1) = meana(:,l) + Kgs0*(ssigma(1:Lin,1)-Hs0*meana(:,l));
        else
            % lag l-1: all m^2 entries of d_k d_{k-l+1}' are used
            EVV = Evv(:,:,L+1);
            EVV2 = Evv(:,:,L+1-l+1);
            W = real(EVV(In5).*EVV2(In6));

            ssigma = reshape(d(:,L+1)*d(:,L+1-l+1)',m*m,1);

            for i=1:p
                Fscript(:,i) = reshape(F(:,:,i,l),m*m,1);
            end

            Hs = Fscript;
            Kgs = Theta(:,:,l)*Hs'/(Hs*Theta(:,:,l)*Hs'+W);
            Theta(:,:,l+1) = (eye(p)-Kgs*Hs)*Theta(:,:,l);
            meana(:,l+1) = meana(:,l) + Kgs*(ssigma-Hs*meana(:,l));
        end

        % stop the run as soon as any component of the estimate is NaN
        % (the secondary filter has diverged)
        if any(isnan(meana(:,l+1)))
            fprintf('NaN in stochastic-parameter estimate at cycle k = %d; stopping.\n', k);
            return
        end
    end
    ahat(:,k+1) = max(meana(:,LT+1),0);   % keep the estimates non-negative
    Th = Theta(:,:,LT+1);

end
grey = [0.4, 0.4, 0.4];   % colour of the filter-estimate curves in the plots

% posterior ensemble mean (n x TCYC)
meanxa = squeeze(mean(xa,2));

% Statistics are averaged over cycles spinup+1..TCYC; the first spinup
% cycles are discarded as filter spin-up. Capped so the window is never
% empty when TCYC is small.
spinup = min(2000, TCYC-1);

% RMS error of the posterior mean for the three state variables
% (NOTE: the name rms may shadow an rms function on the path)
rmsa = (meanxa(1:3,:)-x(1:3,1:TCYC)).^2;
rms = sqrt(mean(mean(rmsa(:,spinup+1:TCYC))))

% ensemble spread: square root of the mean posterior variance of the
% three state variables
Xa = zeros(3,TCYC,En);
for i=1:En
    Xa(:,:,i) = squeeze(xa(1:3,i,:)) - meanxa(1:3,:);
end
Padiag = sum(Xa.^2,3)/(En-1);
spread = sqrt(mean(mean(Padiag(:,spinup+1:TCYC))))

tplot = (1:TCYC)*DT;   % time of each assimilation cycle

%rmsa = sqrt(mean((meanxa(1:n2,:)-x(:,1:TCYC)).^2));
%rms = sqrt(mean(rmsa.^2))
toc

figure(1)
% Posterior mean (grey) against the true state (dashed) for u, v, w.
statenames = {'u','v','w'};
for j = 1:3
    subplot(3,1,j)
    hold on
    plot(tplot, meanxa(j,:), 'color', grey, 'linewidth', 2)
    plot(tplot, x(j,1:TCYC), 'k--')
        %if (j==1)
        %legend('observing u', 'observing (u,v,w)','true')
        %end
    ylabel(statenames{j})
    hold off
    xlim([900 1000])
end
%print -depsc -r100 state_etkfqrt.eps

figure(2)
% Deterministic parameters (rows 4..7 of the posterior mean) against their
% true values, drawn as dashed horizontal lines spanning the whole run.
% The truth lines use their two end points, so they do not require
% TCYC*DT to be a multiple of 20 (ones() needs an integer length).
tend = TCYC*DT;
parnames = {'\omega','\beta','\gamma','a'};
partrue = [omega, beta, gamma, a];
for j = 1:4
    subplot(4,1,j)
    hold on
    plot(tplot(1:10:end),meanxa(3+j,1:10:end),'color',grey,'linewidth',2)
    plot([0 tend],partrue(j)*[1 1],'k--')
    hold off
    ylabel(parnames{j})
    if (j==2 || j==4)
        ylim([0 2])   % beta and a
    end
end
xlabel('time')
orient tall
%print -depsc -r100 dparm_etkfqrt.eps

figure(3)
% Stochastic-parameter estimates against their true values (dashed).
% ahat(:,k+1) is the estimate after cycle k, so it is aligned with tplot(k).
% The truth lines use their two end points, so they do not require
% TCYC*DT to be a multiple of 20 (ones() needs an integer length).
tend = TCYC*DT;

subplot(2,1,1)
hold on
% first parameter shown as sigma = sqrt(ahat(1)/DT), inverting alpha(1) = sigma(1,1)^2*DT
plot(tplot(1:10:end),real(sqrt(ahat(1,2:10:end)/DT)),'color',grey,'linewidth',2)
plot([0 tend],sqrt(alpha(1)/DT)*[1 1],'k--')
hold off
ylim([0 1])
ylabel('\sigma')

subplot(2,1,2)
hold on
plot(tplot(1:10:end),ahat(2,2:10:end),'color',grey,'linewidth',2)
plot([0 tend],R*[1 1],'k--')
hold off
ylim([0 10*R])
ylabel('R')
xlabel('time')
%print -depsc -r100 sparm_etkfqrt.eps






